package nsalt.fetch.branch

import chisel3._
import chisel3.util._

import nsalt._

import chiselFv._


// Adapted from NutShell's return address stack (RAS) implementation:
// https://github.com/OSCPU/NutShell/blob/fd86beadfc47f52973270ce6109edebd2a30363b/src/main/scala/nutcore/frontend/BPU.scala#L117

/** Circular return-address stack (RAS) used to predict RET targets.
  *
  * A valid CALL pushes `pc + 4` (the address of the instruction following a
  * 4-byte CALL) into the slot above the current top and moves the stack
  * pointer there. A valid RET moves the stack pointer down by one. The entry
  * currently pointed at is registered into `io.out.target` whenever
  * `io.in.pc.valid` is asserted.
  *
  * The stack pointer wraps modulo `entryCount`: pushing onto a full stack
  * overwrites the oldest entry, and popping an empty stack wraps to the
  * highest slot (yielding a stale prediction rather than an error).
  *
  * @param entryCount number of stack entries; must be at least 2
  */
class ReturnAddressStack(val entryCount: Int) extends Module with Formal {
  // Counter(1) has no backing register for `value`, so the push/pop
  // assignments below cannot work with a single entry.
  require(entryCount > 1, s"ReturnAddressStack needs at least 2 entries, got $entryCount")

  val io = IO(new Bundle {

    val in = new PredictPort()

    val out = new Bundle {
      val target = Output(UInt(32.W))
    }
  })

  /** Slot above `ptr`, wrapping from `entryCount - 1` back to 0.
    *
    * A plain `ptr + 1.U` wraps at 2^width instead, which walks past the end
    * of `mem` when `entryCount` is not a power of two.
    */
  private def wrapInc(ptr: UInt): UInt =
    Mux(ptr === (entryCount - 1).U, 0.U, ptr + 1.U)

  /** Slot below `ptr`, wrapping from 0 up to `entryCount - 1`. */
  private def wrapDec(ptr: UInt): UInt =
    Mux(ptr === 0.U, (entryCount - 1).U, ptr - 1.U)

  val mem = Mem(entryCount, UInt(32.W))
  val sp = Counter(entryCount)

  val fb = io.in.feed

  when (fb.valid) {
    when (fb.operType === ArithLogicOperType.call) {
      // Push: write the return address above the current top and point at it.
      val pushSlot = wrapInc(sp.value)
      mem.write(pushSlot, fb.pc + 4.U)
      sp.value := pushSlot
    }.elsewhen (fb.operType === ArithLogicOperType.ret) {
      // Pop: the popped entry stays in `mem` but is no longer pointed at.
      sp.value := wrapDec(sp.value)
    }
  }

  // Combinational read of the entry the stack pointer currently selects.
  val memRead = mem.read(sp.value)

  io.out.target := RegEnable(memRead, io.in.pc.valid)

  // ---------------------------------------------------------------------
  // Formal properties
  // ---------------------------------------------------------------------

  // Previous-cycle samples are created outside every `when`: a RegNext built
  // inside a `when` only updates while that condition holds, so it would not
  // reliably carry the value from exactly one cycle earlier.
  val prevFeedValid = RegNext(fb.valid)
  val prevOperType = RegNext(fb.operType)
  val prevPc = RegNext(fb.pc)
  val prevMemRead = RegNext(memRead)

  // The stack pointer moves by exactly one (with wrap) on push/pop and is
  // unchanged otherwise; after a push the new top holds the return address.
  past(sp.value, 1) { prevSp =>
    when (prevFeedValid) {
      when (prevOperType === ArithLogicOperType.call) {
        assert(sp.value === wrapInc(prevSp))
        assert(memRead === prevPc + 4.U)
      }.elsewhen (prevOperType === ArithLogicOperType.ret) {
        assert(sp.value === wrapDec(prevSp))
      }.otherwise {
        assert(sp.value === prevSp)
      }
    }
  }

  // When a prediction was requested last cycle, the output holds the top of
  // stack as it was in that cycle.
  past(io.in.pc.valid, 1) { pastValid =>
    when (pastValid) {
      assert(io.out.target === prevMemRead)
    }
  }
}
